e3nn repository

@misc{mario_geiger_2019_3348277,
author = {Mario Geiger and
Tess Smidt and
Wouter Boomsma and
Maurice Weiler and
Michał Tyszkiewicz and
Jes Frellsen and
Benjamin K. Miller and
Josh Rackers},
title = {e3nn/e3nn: Point cloud support},
month = jul,
year = 2019,
doi = {10.5281/zenodo.3348277},
url = {https://doi.org/10.5281/zenodo.3348277}
}
In this tutorial we show how an E3NN network can be used to predict electron densities. One reason this might be a good idea is that electron densities can be represented in a spherical harmonic basis on atom centers. This fits naturally with the E3NN framework.
%load_ext autoreload
%autoreload 2
import numpy as np
import pickle
import torch
import random
from functools import partial
from e3nn.kernel import Kernel
from e3nn.point.operations import Convolution
from e3nn.non_linearities import GatedBlock
from e3nn.non_linearities import rescaled_act
from e3nn.non_linearities.rescaled_act import relu, sigmoid
from e3nn.radial import CosineBasisModel
from e3nn.radial import GaussianRadialModel
# Work in double precision throughout the notebook.
torch.set_default_dtype(torch.float64)
# Use the first GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
First we load the data. I have saved this in a pickle. For this particular example the dataset is ~6000 water dimer structures, with the density represented in the "a2" density-fitting basis set.
Oxygen: 8s, 4p, 4d
Hydrogen: 4s, 1p, 1d
## load density data
# SECURITY NOTE: pickle.load executes arbitrary code from the file — only
# load pickles you produced yourself.
picklename = "./density_data/dimer_data.pckl"
with open(picklename, 'rb') as f:
    # Per-structure tensors: fitting coefficients, one-hot species features,
    # geometries, atom-type index maps, plus the per-type output
    # representations and coefficients grouped by atom type.
    dataset_coeffs, dataset_onehot, dataset_geom, dataset_typemap, Rs_out_list, coeff_by_type = pickle.load(f)
The thing that makes this task tricky is predicting different numbers of spherical harmonics on each atomic center. To address this problem, we introduce a class, Mixer, to handle this.
Our network consists of 3 layers.
## define model
class Mixer(torch.nn.Module):
    """Apply one operation per input representation and sum the results.

    Instantiates ``Op(Rs_in, Rs_out)`` once for every entry of ``Rs_in_s`` so
    that several input streams (possibly in different representations) can
    contribute to a single output representation.
    """

    def __init__(self, Op, Rs_in_s, Rs_out):
        super().__init__()
        # One Op per input representation, registered via ModuleList so
        # their parameters belong to this module.
        self.ops = torch.nn.ModuleList(Op(rs, Rs_out) for rs in Rs_in_s)

    def forward(self, *args, n_norm=1):
        # Each positional argument is the argument tuple for the matching op;
        # the per-op outputs are simply summed.
        return sum(op(*arg, n_norm=n_norm) for op, arg in zip(self.ops, args))
class Network(torch.nn.Module):
    """Equivariant network predicting per-atom density-fitting coefficients.

    Shared equivariant convolution layers are followed by one "split" final
    layer per atom type (here oxygen and hydrogen), each predicting the
    coefficients of that type's density-fitting basis.
    """

    def __init__(self, Rs_in, Rs_out_list, max_radius=3.0, number_of_basis=3,
                 radial_layers=3, basistype="Gaussian"):
        """
        Rs_in           : input representation, e.g. [(num_species, 0)];
                          only L=0 (one-hot) inputs are supported (see note).
        Rs_out_list     : one output representation per atom type, in the
                          same order as the atom type map used in forward().
        max_radius      : radial cutoff of the convolution kernels.
        number_of_basis : number of radial basis functions.
        radial_layers   : depth of the radial MLP.
        basistype       : "Gaussian" or "Cosine".

        Raises ValueError for any other basistype (the previous code only
        printed a warning and then crashed later with a NameError).
        """
        super().__init__()
        #sp = rescaled_act.Softplus(beta=5)
        #sp = rescaled_act.ShiftedSoftplus(beta=5)
        sp = torch.nn.Tanh()
        # The [0] is just to get first_layer in stripped (multiplicity) form.
        # This will not work for Rs_in containing more than L=0.
        first_layer = Rs_in[0]
        last_shared_layer = (2, 1, 1)  # multiplicities for L = 0, 1, 2
        representations = [first_layer, last_shared_layer]
        representations = [[(mul, l) for l, mul in enumerate(rs)] for rs in representations]

        if basistype == 'Gaussian':
            rad_basis = GaussianRadialModel
        elif basistype == 'Cosine':
            rad_basis = CosineBasisModel
        else:
            raise ValueError("Only Gaussian and Cosine Radial basis are currently supported")

        RadialModel = partial(rad_basis, max_radius=max_radius,
                              number_of_basis=number_of_basis, h=100,
                              L=radial_layers, act=sp)
        K = partial(Kernel, RadialModel=RadialModel)
        C = partial(Convolution, K)
        M = partial(Mixer, C)  # wrap C to accept many input streams

        def make_layer(Rs_layer_in, Rs_layer_out):
            # Convolution maps into the gated nonlinearity's input rep.
            act = GatedBlock(Rs_layer_out, sp, sigmoid)
            conv = C(Rs_layer_in, act.Rs_in)
            return torch.nn.ModuleList([conv, act])

        self.layers = torch.nn.ModuleList([
            make_layer(Rs_layer_in, Rs_layer_out)
            for Rs_layer_in, Rs_layer_out in zip(representations[:-1], representations[1:])
        ])

        ## split final layer: one Mixer per atom type, each consuming both
        ## atom-type feature streams; indexed in order of atom type.
        self.final_layer = torch.nn.ModuleList([
            M([representations[-1], representations[-1]], rs)
            for rs in Rs_out_list
        ])

    def forward(self, input, geometry, atom_type_map):
        """
        input         : one-hot features, assumed shape (batch, atoms, features)
                        — TODO confirm against the data pipeline.
        geometry      : atom positions, shape (batch, atoms, 3).
        atom_type_map : per-type atom index lists; element 0 selects oxygens,
                        element 1 hydrogens.
        Returns [outputO, outputH], the predicted coefficients per atom type.
        NOTE(review): only batch index 0 is used below — confirm batch size 1
        is always intended.
        """
        output = input
        batch, N, _ = geometry.shape
        # Shared layers over all atoms.
        for conv, act in self.layers:
            output = conv(output, geometry, n_norm=N)
            output = act(output)

        ## split final layer: regroup features/positions by atom type
        geometry_list = []
        feature_list = []
        for item in atom_type_map:
            geometry_list.append(geometry[0][item])
            feature_list.append(output[0][item])

        ## this is assuming that there are only two atom types!
        ## it should work, though, for any arbitrary order of O and H in the xyz file
        featuresO = feature_list[0].unsqueeze(0)
        featuresH = feature_list[1].unsqueeze(0)
        geometryO = geometry_list[0].unsqueeze(0)
        geometryH = geometry_list[1].unsqueeze(0)

        final_layer_output = []
        for i, layer in enumerate(self.final_layer):
            # Layer 0 predicts on oxygen centers, layer 1 on hydrogen
            # centers; each mixes contributions from both atom-type streams.
            if i == 0:
                final = layer((featuresO, geometryO, geometryO),
                              (featuresH, geometryH, geometryO), n_norm=N)
            if i == 1:
                final = layer((featuresO, geometryO, geometryH),
                              (featuresH, geometryH, geometryH), n_norm=N)
            final_layer_output.append(final)

        # list of [outputO, outputH]
        return final_layer_output
Let's initialize a rough model. The parameters are: the radial cutoff of the convolutions (`maxradius`), the number of radial basis functions (`numbasis`), the depth of the radial network (`radiallayers`), and the type of radial basis (`radialbasis`).
We pass these parameters in as a dictionary so that we can save them for later use if we want to save the model.
Then we send the model to the GPU
The output shows us a helpful schematic of what kinds of operations our network is going to use.
## set arguments to network
# Hyperparameters collected in one dict so they can be saved alongside the
# model for later reconstruction. Rs_in is derived from the one-hot vector.
network_kwargs = {
    "Rs_in": [(len(dataset_typemap[0]), 0)],
    "Rs_out_list": Rs_out_list,
    "max_radius": 3.0,
    "number_of_basis": 20,
    "radial_layers": 3,
    "basistype": "Gaussian",
}
print("Rs_in:", network_kwargs["Rs_in"])
print("\nOxygen Rs_out:", Rs_out_list[0])
print("Hydrogen Rs_out:", Rs_out_list[1])
net = Network(**network_kwargs)
#net.to(device)
From here, training the model looks virtually identical to any other training one might do with a typical neural network in pytorch. In this case we are going to use the Adam optimizer and minibatches.
## set up training
net.train()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
optimizer.zero_grad()
loss_fn = torch.nn.modules.loss.MSELoss()

max_steps = 2000
minibatch_size = 16
print(device)

for step in range(max_steps):
    # Train on all but the last 3000 structures; those are held out for the
    # test evaluation below. NOTE(review): the two ranges only stay disjoint
    # while len(dataset_geom) <= 6000 — confirm for larger datasets.
    i = random.randint(0, len(dataset_geom) - 3001)
    onehot = dataset_onehot[i]
    points = dataset_geom[i]
    atom_type_map = dataset_typemap[i]
    coeffs = dataset_coeffs[i]

    outputO, outputH = net(onehot.to(device), points.to(device), atom_type_map)
    outputO = torch.flatten(outputO)
    outputH = torch.flatten(outputH)
    # Concatenate per-type predictions into one (1, 1, -1) vector to match
    # the stored coefficient layout.
    output = torch.cat((outputO, outputH), 0).view(1, 1, -1)
    loss = loss_fn(output, coeffs)
    step_loss = loss.item()
    loss.backward()

    # Gradient accumulation: step only every `minibatch_size` structures.
    # NOTE(review): accumulated gradients are not divided by the minibatch
    # size, so the effective learning rate is lr * minibatch_size.
    if (step + 1) % minibatch_size == 0:
        optimizer.step()
        optimizer.zero_grad()

    if step % 100 == 0:
        print('\nStep {0}, Loss {1}'.format(step, step_loss))
        # Held-out evaluation on one of the last 3000 structures.
        j = random.randint(3000, len(dataset_geom) - 1)
        onehot = dataset_onehot[j]
        # NOTE(review): the geometry is scaled by 3 here — presumably a
        # deliberate extrapolation test on stretched dimers; confirm intended.
        points = dataset_geom[j] * 3
        atom_type_map = dataset_typemap[j]
        coeffs = dataset_coeffs[j]
        # Evaluation only: skip autograd bookkeeping.
        with torch.no_grad():
            outputO, outputH = net(onehot.to(device), points.to(device), atom_type_map)
            outputO = torch.flatten(outputO)
            outputH = torch.flatten(outputH)
            output = torch.cat((outputO, outputH), 0).view(1, 1, -1)
            loss = loss_fn(output.to(device), coeffs.to(device))
        print('\nTest Loss {0}'.format(loss.item()))
Let's check to see if the number of electrons is in the ballpark.
# Sanity check: does the predicted density correspond to roughly the right
# number of electrons? (helper from density_analysis_utils)
from density_analysis_utils import *
# args: model, device, structure index, basis set file, then the dataset
# one-hots / geometries / type maps and reference coefficients by type
testnumelectrons(net,device,2,"./density_data/a2.gbs",dataset_onehot,dataset_geom,dataset_typemap,coeff_by_type)
To do this we need three components for each function:
First we need to do some data wrangling. The basis set we're using has an annoying property: it has 'SP' functions. This means one entry specifies an S function and a P function simultaneously.
Last we set up our radial function with the above values.
from density_analysis_utils import *
from e3nn.rs import dim, mul_dim


## Gaussian Type Orbital radial basis: norm * exp(-alpha * r^2), broadcast
## over every (alpha, norm) pair by the trailing unsqueeze on r.
def basis(r, alpha, norm):
    return norm * torch.exp(-alpha * r.unsqueeze(-1) ** 2)


## exponents and normalization constants of the "a2" fitting basis
alphaO, alphaH = get_exponents('./density_data/a2.gbs')
normO, normH = parse_whole_normfile('./density_data/a2_norm.dat')
normO, normH = torch.FloatTensor(normO), torch.FloatTensor(normH)

## spherical harmonic normalization constants, per atom type
Rs_out_O, Rs_out_H = Rs_out_list[0], Rs_out_list[1]
sph_normsO, sph_normsH = get_spherical_harmonic_norms(Rs_out_O, Rs_out_H)

## radial functions with the per-type constants baked in
basis_on_r_O = partial(basis, alpha=alphaO, norm=normO)
basis_on_r_H = partial(basis, alpha=alphaH, norm=normH)

## one normalization constant per basis function
assert mul_dim(Rs_out_O) == normO.shape[0]
assert mul_dim(Rs_out_H) == normH.shape[0]
# pick a random structure to test (fixed index for reproducibility)
dimer_num = 4321
onehot = dataset_onehot[dimer_num]
points = dataset_geom[dimer_num]
atom_type_map = dataset_typemap[dimer_num]
# run the trained network; outputs are per-type coefficient tensors
outputO, outputH = net(onehot.to(device),points.to(device),atom_type_map)
# detach to numpy for the density evaluation and plotting below
outputO = outputO.data.cpu().numpy()
outputH = outputH.data.cpu().numpy()
from spherical import plot_data_on_grid
import e3nn.o3 as o3

## Evaluate the predicted density on a grid: one basis expansion per atom,
## scaled by the network's predicted coefficients, then summed over atoms.
coords = points.data.squeeze().numpy()  # hoisted: reused for every center

# Per-atom-type data: (radial basis, Rs_out, SH norms, predicted coeffs).
# NOTE(review): indices 0/1 assume the type map orders oxygen before
# hydrogen, matching Network.forward — confirm for other datasets.
type_params = {
    0: (basis_on_r_O, Rs_out_O, sph_normsO, outputO),  # oxygens
    1: (basis_on_r_H, Rs_out_H, sph_normsH, outputH),  # hydrogens
}

f_list = []
# `type_index` replaces the original loop variable `type`, which shadowed
# the builtin of the same name.
for type_index, atom_indices in enumerate(atom_type_map):
    radial_basis, rs_out, sph_norms, coeff_out = type_params[type_index]
    # loop over atoms of this type
    for count, atom in enumerate(atom_indices):
        tot_f = 0
        center = coords[atom]
        # NOTE(review): the grid `r` is re-centered on each atom, yet the
        # per-atom fields are summed (and later plotted on the LAST `r`) —
        # confirm plot_data_on_grid evaluates on a consistent absolute grid
        # so this summation is meaningful.
        r, f = plot_data_on_grid(5.0, radial_basis, rs_out, n=20, center=center)
        # sum up contributions from every basis function on this atom
        for j, c in enumerate(coeff_out.squeeze()[count]):
            tot_f += c * f[:, j] / sph_norms[j]
        f_list.append(tot_f)

all_atom_f = sum(f_list)
print(all_atom_f.max())
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go

## Volume-render the summed density; isosurface levels are set relative to
## the field's maximum so the plot adapts to the prediction's scale.
plot_max = float(all_atom_f.max())
fig = go.Figure(data=go.Volume(
    x=r[:, 0],
    y=r[:, 1],
    z=r[:, 2],
    value=all_atom_f,
    isomin=-0.005 * plot_max,
    isomax=0.005 * plot_max,
    opacity=0.3,  # needs to be small to see through all surfaces
    opacityscale="uniform",
    surface_count=50,  # needs to be a large number for good volume rendering
    colorscale='RdBu'))

# Overlay the atom positions. Hoisted: the original recomputed
# points.data.squeeze().numpy() once per coordinate axis.
atom_xyz = points.data.squeeze().numpy()
fig.add_scatter3d(x=atom_xyz[:, 0], y=atom_xyz[:, 1], z=atom_xyz[:, 2],
                  mode='markers',
                  marker=dict(size=12, color='Black', opacity=1.0))
fig.show()